from FSaug import *


class FSAugTrain(object):
    """Training-time augmentation pipeline built from ``FSaug`` transforms.

    The pipeline is a ``Compose`` of, in order:

    * ``Resize`` with ``select=[0, 2, 5, 6, 7]``
    * ``RandomUpperCrop`` with ``select=[1, 3, 4]``
    * ``RandomHflip``
    * ``Normalize``

    Calling the instance forwards ``(image, attr_idx)`` to that pipeline.
    """

    def __init__(self, *, size=(336, 336),
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """Build the training pipeline.

        Args:
            size: Output size passed to both ``Resize`` and
                ``RandomUpperCrop``. Defaults to the previously hard-coded
                ``(336, 336)``.
            mean: Per-channel mean handed to ``Normalize`` (as a list).
                The default appears to be the widely used ImageNet RGB mean.
            std: Per-channel std handed to ``Normalize`` (as a list).
                The default appears to be the widely used ImageNet RGB std.
        """
        self.augment = Compose([
            # NOTE(review): `select` presumably lists the attr_idx values for
            # which each transform is applied (resize vs. upper-crop) —
            # confirm against FSaug.Compose.
            Resize(size=size, select=[0, 2, 5, 6, 7]),
            RandomUpperCrop(size=size, select=[1, 3, 4]),
            RandomHflip(),
            # Defaults are immutable tuples; pass fresh lists so Normalize
            # receives the same type it always did.
            Normalize(mean=list(mean), std=list(std)),
        ])

    def __call__(self, image, attr_idx):
        """Apply the pipeline to ``image`` for ``attr_idx`` and return its result."""
        return self.augment(image, attr_idx)


class FSAugVal(object):
    """Validation-time augmentation pipeline built from ``FSaug`` transforms.

    The pipeline is a ``Compose`` of, in order:

    * ``Resize`` with ``select=[0, 2, 5, 6, 7]``
    * ``UpperCrop`` with ``select=[1, 3, 4]``
    * ``Normalize``

    Unlike the training pipeline, it uses ``UpperCrop`` rather than
    ``RandomUpperCrop`` and has no ``RandomHflip``. Calling the instance
    forwards ``(image, attr_idx)`` to the pipeline.
    """

    def __init__(self, *, size=(336, 336),
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """Build the validation pipeline.

        Args:
            size: Output size passed to both ``Resize`` and ``UpperCrop``.
                Defaults to the previously hard-coded ``(336, 336)``.
            mean: Per-channel mean handed to ``Normalize`` (as a list).
                The default appears to be the widely used ImageNet RGB mean.
            std: Per-channel std handed to ``Normalize`` (as a list).
                The default appears to be the widely used ImageNet RGB std.
        """
        self.augment = Compose([
            # NOTE(review): `select` presumably lists the attr_idx values for
            # which each transform is applied (resize vs. upper-crop) —
            # confirm against FSaug.Compose.
            Resize(size=size, select=[0, 2, 5, 6, 7]),
            UpperCrop(size=size, select=[1, 3, 4]),
            # Defaults are immutable tuples; pass fresh lists so Normalize
            # receives the same type it always did.
            Normalize(mean=list(mean), std=list(std)),
        ])

    def __call__(self, image, attr_idx):
        """Apply the pipeline to ``image`` for ``attr_idx`` and return its result."""
        return self.augment(image, attr_idx)


